/*************************************************************************
 * The contents of this file are subject to the MYRICOM MYRINET          *
 * EXPRESS (MX) NETWORKING SOFTWARE AND DOCUMENTATION LICENSE (the       *
 * "License"); User may not use this file except in compliance with the  *
 * License.  The full text of the License can found in LICENSE.TXT       *
 *                                                                       *
 * Software distributed under the License is distributed on an "AS IS"   *
 * basis, WITHOUT WARRANTY OF ANY KIND, either express or implied.  See  *
 * the License for the specific language governing rights and            *
 * limitations under the License.                                        *
 *                                                                       *
 * Copyright 2003 - 2004 by Myricom, Inc.  All rights reserved.          *
 *************************************************************************/

#include "mx_auto_config.h"
#include "myriexpress.h"
#include "mx__lib_types.h"
#  include "mx__partner.h"
#include "mx__endpoint.h"
#include "mx__shmem.h"

#if MX_USE_SHMEM

#include "mx__shim.h"
#include "mx_byteswap.h"
#include "mx__cmpswap.h"
#include "mx__segment.h"
#include "mx_stbar.h"
#include "mx_util.h"
#include "mx__lib.h"
#include "mx__request.h"
#include "mx__driver_interface.h"
#include "mx__valgrind.h"

#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
#include <string.h>
#include <errno.h>
#include <signal.h>
#include <unistd.h>

/*
 * Constant XORed with the address of a large send request to produce the
 * value stored in its shm_magic field (mx__shm_send); the same XOR is
 * recomputed from the request pointer carried by the ACK and compared
 * against that field in mx__shmem_luigi.
 */
#define MX__SHM_REQ_MAGIC UINT64_C(0x62584a101b071341)

/*
 * Build the path of a shared-memory backing file and report its size.
 *
 * board_num: board number embedded in the name
 * data:      non-zero selects the data FIFO file (named after both
 *            endpoints), zero the control file (named after owner only)
 * owner:     endpoint id of the endpoint that creates the file
 * peer:      endpoint id of the other side; only used for data files
 * fname:     output buffer.  No length is passed in, so the caller must
 *            provide enough room; callers in this file use 100 bytes.
 *
 * Returns MX__SHM_FIFO_FILESIZE for a data file and MX__SHM_REQ_FILESIZE
 * for a control file.  The value is not rounded to the page size.
 */
static uint32_t
mx__shm_filename(uint32_t board_num, int data, uint16_t owner, uint16_t peer, char *fname)
{
  /* Each argument is cast to int so that it really has the type its %d
     conversion expects; board_num is a uint32_t and would not otherwise. */
  if (data) {
    sprintf(fname, "/var/tmp/mxshmem-data-u%d-b%d-e%dx%d",
	    (int)getuid(), (int)board_num, (int)owner, (int)peer);
    return MX__SHM_FIFO_FILESIZE;
  }
  sprintf(fname, "/var/tmp/mxshmem-ctrl-u%d-b%d-e%d",
	  (int)getuid(), (int)board_num, (int)owner);
  return MX__SHM_REQ_FILESIZE;
}

/*
 * Map one of the shared-memory files used to talk to a local peer.
 *
 * ep:     local endpoint
 * peer:   endpoint id of the other side
 * create: non-zero when this endpoint owns the file: a file left over
 *         under the same name is removed, a new one is created, sized
 *         and zero-filled.  Zero to attach to a file owned by `peer`.
 * data:   non-zero for the data FIFO file, zero for the control file
 *
 * Returns the address of the mapping, or 0 when the file cannot be
 * opened, when a newly created file cannot be sized, or, when attaching
 * to a control file, when the pid stored in the queue is zero or
 * kill(pid, 0) fails for it.
 */
void *
mx__shm_open(struct mx_endpoint *ep, uint16_t peer, int create, int data)
{
  char fname[100];
  void *ptr;
  int fd;
  uint32_t len;
  long pagesize;

  mx_always_assert(sizeof(struct mx__shmreq) == 512);
  mx_always_assert(sizeof(struct mx__shm_queue) == sizeof(struct mx__shmreq) + 128 * 2);
  /* the file is named after its creator first, then the other side */
  len = mx__shm_filename(ep->board_num, data,
		   create ? ep->myself->eid : peer,
		   create ? peer : ep->myself->eid,
		   fname);
  /* round the mapping length up to a whole number of pages */
  pagesize = sysconf(_SC_PAGESIZE);
  len = (len + pagesize - 1) & ~(pagesize - 1);
  if (create)
    unlink(fname);
  fd = open(fname,O_RDWR | (create ? O_TRUNC | O_EXCL | O_CREAT : 0),0666);
  if (fd < 0) {
    return 0;
  }
  if (create && ftruncate(fd, len) != 0) {
    /* The file has no backing store: touching the mapping in the memset
       below would fault.  Remove the unusable file and report failure. */
    close(fd);
    unlink(fname);
    return 0;
  }
  ptr = mmap(0, len, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
  /* the mapping stays valid after the descriptor is closed */
  close(fd);
  mx_fixme_assert(ptr != MAP_FAILED);
  if (create) {
    memset(ptr,0, len);
  } else if (!data) {
    /* attaching to a peer's control queue: refuse it unless the recorded
       owner pid is set and can be signalled */
    struct mx__shm_queue *queue = ptr;
    if (!queue->pid || kill(queue->pid, 0) != 0) {
      munmap(ptr,len);
      ptr = 0;
    }
  }
  return ptr;
}

/*
 * Release a mapping obtained from mx__shm_open().
 *
 * ep:      local endpoint
 * ptr:     mapping address, may be 0 (nothing is unmapped then)
 * peer:    endpoint id of the other side
 * creator: non-zero if this endpoint created the file; the file is then
 *          unlinked, whether or not ptr is set
 * data:    non-zero for the data FIFO file, zero for the control file
 *
 * The length handed to munmap() is the size reported by
 * mx__shm_filename(), i.e. it is not rounded up to the page size.
 */
void
mx__shm_close(struct mx_endpoint *ep, void *ptr, uint16_t peer, int creator, int data)
{
  char path[100];
  uint16_t owner_eid;
  uint16_t other_eid;
  uint32_t map_len;

  /* same naming rule as in mx__shm_open(): creator first */
  if (creator) {
    owner_eid = ep->myself->eid;
    other_eid = peer;
  } else {
    owner_eid = peer;
    other_eid = ep->myself->eid;
  }
  map_len = mx__shm_filename(ep->board_num, data, owner_eid, other_eid, path);

  if (creator)
    unlink(path);
  if (ptr)
    munmap(ptr, map_len);
}

/*
 * A request that could not be placed directly in a peer's shared queue.
 * It is allocated by mx__bounce_slot(), linked on the bounce_reqq list of
 * the local shm state, and copied into the peer's queue and freed by
 * mx__process_bounces() once a slot can be reserved.
 */
struct mx__bounce_req {
  /* private copy of the request; callers fill it in place of a queue slot */
  struct mx__shmreq req;
  /* peer whose send queue the request is destined for */
  struct mx__shm_peer *peer;
  /* linkage in the bounce_reqq list */
  STAILQ_ENTRY(mx__bounce_req) next;
};

/*
 * Drop the mappings this endpoint holds on files owned by peer `endpt`:
 * the peer's control queue (snd_shmq) and the FIFO we receive its data
 * from (shm_rcv_fifo).  Both are closed as non-creator, so no file is
 * unlinked, and both pointers are reset to 0.
 */
void
mx__shm_forget_peer(struct mx_endpoint *ep, uint16_t endpt)
{
  struct mx__shm_peer *p = &ep->shm->peers[endpt];

  /* control file of the peer */
  mx__shm_close(ep, p->snd_shmq, endpt, 0, 0);
  p->snd_shmq = 0;

  /* data file of the peer */
  mx__shm_close(ep, p->shm_rcv_fifo, endpt, 0, 1);
  p->shm_rcv_fifo = 0;
}

/*
 * Allocate a bounce request for `peer`, append it to shm->bounce_reqq and
 * return the request embedded in it so the caller can fill it exactly
 * like a real queue slot.  Allocation failure trips mx_fixme_assert().
 */
static struct mx__shmreq *
mx__bounce_slot(struct mx__shm_info *shm, struct mx__shm_peer *peer)
{
  struct mx__bounce_req *entry = mx_malloc(sizeof(*entry));

  mx_fixme_assert(entry);
  entry->peer = peer;
  STAILQ_INSERT_TAIL(&shm->bounce_reqq, entry, next);

  return &entry->req;
}

/*
 * Try to reserve the next slot of a peer's shared request queue.
 *
 * shm:       local shared-memory state; only used to park a bounce request
 * peer:      destination peer; peer->snd_shmq must already be mapped
 * bounce:    selects what happens when the queue is full: non-zero returns
 *            a freshly queued bounce slot, zero returns NULL
 * valid_bit: out; 0 when no queue slot was reserved, otherwise the
 *            MX__SHM_REQ_CNT bit of the reserved index.  Writers add it to
 *            the request type; mx__shmem_luigi() compares that bit of the
 *            stored type with the same bit of its read index to decide
 *            whether the slot holds a new entry.
 *
 * Returns the slot the caller must fill in, or NULL (queue full and
 * bounce == 0).
 */
static inline struct mx__shmreq *
mx__shm_queue_slot(struct mx__shm_info *shm, struct mx__shm_peer *peer, int bounce, uint8_t *valid_bit)
{
  uint32_t old_widx, new_widx;
  struct mx__shm_queue *shmq;
  shmq = peer->snd_shmq;
  /* Claim index old_widx by moving write_idx from old_widx to new_widx.
     The loop retries while the swap reports a value other than the one
     sampled (presumably mx__cmpswap_u32 returns the previous content of
     the word - confirm in mx__cmpswap.h). */
  do {
    old_widx = shmq->write_idx;
    new_widx = old_widx + 1;
    /* treated as full once the claimed index would be MX__SHM_REQ_CNT or
       more ahead of the reader */
    if(new_widx - shmq->read_idx >= MX__SHM_REQ_CNT) {
      *valid_bit = 0;
      return bounce ? mx__bounce_slot(shm, peer) : NULL;
    }
  } while (mx__cmpswap_u32(&shmq->write_idx, old_widx, new_widx) != old_widx);
  *valid_bit = old_widx & MX__SHM_REQ_CNT;
  return shmq->queue + MX__SHM_REQ_SLOT(old_widx);
}

/*
 * Get a request slot for `peer`, always succeeding.
 *
 * While bounce requests are pending on shm->bounce_reqq the new request
 * is parked behind them (valid_bit is set to 0) instead of going to the
 * shared queue - presumably so that requests are not reordered; otherwise
 * a queue slot is reserved, falling back to a bounce slot if the queue is
 * full.
 */
static inline struct mx__shmreq *
mx__shm_slot(struct mx__shm_info *shm, struct mx__shm_peer *peer, uint8_t *valid_bit)
{
  if (STAILQ_EMPTY(&shm->bounce_reqq))
    return mx__shm_queue_slot(shm, peer, 1, valid_bit);

  /* something is already parked: queue behind it */
  *valid_bit = 0;
  return mx__bounce_slot(shm, peer);
}

/*
 * Ask the driver to wake the peer's endpoint, but only if the peer's
 * queue has its `waiting` field set.  The peer's endpoint number is its
 * index in ep->shm->peers.
 */
static inline void
mx__shm_wake_if_needed(mx_endpoint_t ep, struct mx__shm_peer *peer)
{
  mx_wake_endpt_t wake_args;

  MX_READBAR();
  if (!peer->snd_shmq->waiting)
    return;

  wake_args.endpt = peer - ep->shm->peers;
  mx__wake_endpoint(ep->handle, &wake_args);
}

/*
 * Move parked bounce requests into their peers' shared queues, oldest
 * first, stopping at the first one whose destination queue is full.
 *
 * For each request moved: everything up to (not including) the `type`
 * field is copied into the reserved slot, a write barrier is issued, and
 * only then is `type` stored with the slot's valid bit added, after which
 * the bounce entry is freed and the peer is woken if it is waiting.
 */
static void
mx__process_bounces(mx_endpoint_t ep)
{
  struct mx__shm_info *info = ep->shm;

  for (;;) {
    struct mx__bounce_req *head;
    struct mx__shm_peer *dest;
    struct mx__shmreq *qslot;
    uint8_t valid_bit;

    if (STAILQ_EMPTY(&info->bounce_reqq))
      return;

    head = STAILQ_FIRST(&info->bounce_reqq);
    dest = head->peer;
    /* map the destination's control queue if it is not mapped */
    if (!dest->snd_shmq) {
      dest->snd_shmq = mx__shm_open(ep, dest - info->peers, 0, 0);
      mx_fixme_assert(dest->snd_shmq);
    }

    /* bounce == 0: a full queue yields NULL rather than a new bounce */
    qslot = mx__shm_queue_slot(info, dest, 0, &valid_bit);
    if (!qslot)
      return;

    STAILQ_REMOVE_HEAD(&info->bounce_reqq, next);
    memcpy(qslot, &head->req, offsetof(struct mx__shmreq, type));
    MX_WRITEBAR();
    /* type is written last, after the barrier */
    qslot->type = head->req.type + valid_bit;

    head->peer = 0;
    mx_free(head);
    mx__shm_wake_if_needed(ep, dest);
  }
}


/*
 * Reserve `length` bytes in the data FIFO this endpoint owns towards
 * `peer`, creating and mapping the FIFO file on first use.
 *
 * Returns a pointer to the reserved area, or NULL when the bytes sent and
 * not yet released by the receiver (sent - rcvd) plus `length` would
 * exceed MX__SHM_FIFO_LENGTH.
 *
 * NOTE(review): the write index is reset to 0 only after it reaches or
 * passes MX__SHM_FIFO_LENGTH, so the last chunk before a wrap may extend
 * past that offset - presumably the data area has slack for it; confirm
 * against the struct mx__shm_fifo layout.
 */
static inline void *
mx__shm_fifo_send(struct mx_endpoint *ep, struct mx__shm_peer *peer, uint32_t length)
{
  struct mx__shm_fifo *out;
  void *chunk;

  if (!peer->shm_snd_fifo) {
    peer->shm_snd_fifo = mx__shm_open(ep, peer - ep->shm->peers, 1, 1);
    mx_fixme_assert(peer->shm_snd_fifo);
  }
  out = peer->shm_snd_fifo;

  if (out->sent - out->rcvd + length <= MX__SHM_FIFO_LENGTH) {
    chunk = out->data + out->send_idx;
    out->send_idx += length;
    if (out->send_idx >= MX__SHM_FIFO_LENGTH)
      out->send_idx = 0;
    out->sent += length;
    return chunk;
  }
  /* not enough room until the receiver releases some data */
  return NULL;
}

/*
 * Return the address of the next `length` bytes in the data FIFO owned by
 * `peer`, attaching to the FIFO file on first use, and advance the read
 * index with the same wrap rule as the sender (reset to 0 once it
 * reaches or passes MX__SHM_FIFO_LENGTH).  The space is only given back
 * to the sender by mx__shm_fifo_free().
 */
static inline void *
mx__shm_fifo_recv(struct mx_endpoint *ep, struct mx__shm_peer *peer, uint32_t length)
{
  struct mx__shm_fifo *in;
  void *chunk;

  if (!peer->shm_rcv_fifo) {
    peer->shm_rcv_fifo = mx__shm_open(ep, peer - ep->shm->peers, 0, 1);
    mx_fixme_assert(peer->shm_rcv_fifo);
  }
  in = peer->shm_rcv_fifo;

  chunk = in->data + in->recv_idx;
  in->recv_idx += length;
  if (in->recv_idx >= MX__SHM_FIFO_LENGTH)
    in->recv_idx = 0;

  return chunk;
}

static inline void
mx__shm_fifo_free(struct mx_endpoint *ep, struct mx__shm_peer *peer, uint32_t length)
{
  MX_READBAR();
  peer->shm_rcv_fifo->rcvd += length;
}


/*
 * Post send request `q` to the local peer q->basic.partner through
 * shared memory.
 *
 * Three encodings are used, chosen from the request type and length:
 *  - small:  length <= MX__SHM_SMALLSIZE; the payload is copied into the
 *            request slot itself (shm_req->buf).
 *  - medium: length <= MX__SHM_MEDIUMSIZE and the data FIFO has room; the
 *            payload is copied into the FIFO.
 *  - large:  the request was already MX__REQUEST_TYPE_SEND_LARGE, or
 *            neither case above applied.  No payload is copied; the slot
 *            carries the request pointer, the session and a description
 *            of the source segments.  The receiver answers with an ACK
 *            that is handled in mx__shmem_luigi().
 *
 * Small and medium sends are completed before returning; a large send
 * stays pending and is counted in ep->sendshm_count.
 */
void mx__shm_send(struct mx_endpoint *ep, union mx_request *q)
{
  struct mx__shm_peer *peer;
  struct mx__shmreq *shm_req;
  uint32_t type;
  int is_large;
  void *data;
  uint8_t valid_bit;
  struct mx__partner * partner = q->basic.partner;

  peer = ep->shm->peers + partner->eid;
  /* map the peer's control queue on first use */
  if (!peer->snd_shmq) {
    peer->snd_shmq = mx__shm_open(ep, partner->eid, 0, 0);
    mx_fixme_assert(peer->snd_shmq);
  }
  /* may be a real queue slot or a parked bounce request (valid_bit 0) */
  shm_req = mx__shm_slot(ep->shm, peer, &valid_bit);
  is_large = q->send.basic.type == MX__REQUEST_TYPE_SEND_LARGE;
  if (!is_large && q->send.basic.status.msg_length <= MX__SHM_SMALLSIZE) {
    type = MX__SHM_REQ_SMALL;
    if (q->send.basic.status.msg_length)
      mx__copy_from_segments(shm_req->buf, q->send.segments, q->send.count,
			     q->send.memory_context, 0, q->send.basic.status.msg_length);
  } else if (!is_large && q->send.basic.status.msg_length <=  MX__SHM_MEDIUMSIZE &&
	(data = mx__shm_fifo_send(ep, peer, q->send.basic.status.msg_length))) {
    type = MX__SHM_REQ_MEDIUM;
    mx__copy_from_segments(data, q->send.segments, q->send.count,
			   q->send.memory_context, 0, q->send.basic.status.msg_length);
  } else {
    /* A request that was not large to begin with ends up here when it is
       too long for the medium path or the FIFO had no room: mark it
       large so that it is completed by the ACK. */
    q->send.basic.type = MX__REQUEST_TYPE_SEND_LARGE;
    is_large = 1;

    type = MX__SHM_REQ_LARGE;
    shm_req->req_ptr = (uintptr_t)q;
    shm_req->src_session = ntohl(ep->endpoint_sid_n);
    if (q->send.count > 1) {
      /* Several segments: src_segs.vaddr points to a malloc'ed array of
	 descriptors (src_segs.len is not set in this case).  The array is
	 freed when the ACK comes back, see mx__shmem_luigi(). */
      int i;
      mx_shm_seg_t * segs = mx_malloc (q->send.count * sizeof(*segs));
      mx_fixme_assert(segs);
      for(i=0; i < q->send.count; i++) {
	segs[i].vaddr = (uintptr_t)q->send.segments[i].segment_ptr;
	segs[i].len = q->send.segments[i].segment_length;
      }
      shm_req->src_segs.vaddr = (uintptr_t) segs;
    } else {
      /* a single segment is described inline */
      shm_req->src_segs.vaddr = (uintptr_t)q->send.segments[0].segment_ptr;
      shm_req->src_segs.len = q->send.segments[0].segment_length;
    }
    shm_req->src_nsegs = q->send.count;
    /* checked against the request pointer carried by the ACK */
    q->send.shm_magic = (uintptr_t)q ^ MX__SHM_REQ_MAGIC;
  }
  shm_req->length = q->send.basic.status.msg_length;
  shm_req->match_info = q->send.basic.status.match_info;
  shm_req->peer_endpt = ep->myself->eid;
  /* the type, carrying the valid bit, is stored last, after a barrier */
  MX_WRITEBAR();
  shm_req->type = type + valid_bit;
  MX_WRITEBAR();
  mx__shm_wake_if_needed(ep, peer);
  q->send.basic.status.xfer_length = q->send.basic.status.msg_length;
  q->send.basic.state = MX__REQUEST_STATE_PENDING;
  if (!is_large) {
    mx__send_complete(ep, q, MX_STATUS_SUCCESS);
  } else {
    /* decremented when the ACK is processed */
    ep->sendshm_count += 1;
  }
}

/*
 * Queue an MX__SHM_REQ_ACK request for the peer that issued a large send.
 *
 * shm:        local shared-memory state
 * peer_endpt: endpoint id of the sender being acknowledged
 * req:        sender-side request pointer, echoed back in req_ptr
 * length:     number of bytes transferred
 * src_segs:   value echoed in src_segs.vaddr (src_segs.len is not set)
 * src_nsegs:  number of source segments, echoed back
 *
 * The slot may be a parked bounce request.  The caller is responsible for
 * having the peer's queue mapped and for waking the peer.
 */
void mx__shm_large_ack(struct mx__shm_info *shm, uint16_t peer_endpt, uint64_t req,
		       uint32_t length, uint64_t src_segs, uint32_t src_nsegs)
{
  uint8_t valid_bit;
  struct mx__shm_peer *sender = &shm->peers[peer_endpt];
  struct mx__shmreq *ack = mx__shm_slot(shm, sender, &valid_bit);

  ack->req_ptr = req;
  ack->length = length;
  ack->src_segs.vaddr = (uintptr_t) src_segs;
  ack->src_nsegs = src_nsegs;

  /* the type, with the valid bit, goes last, after the barrier */
  MX_WRITEBAR();
  ack->type = MX__SHM_REQ_ACK + valid_bit;
}

/*
 * Pull the payload of a large shared-memory send into a matched receive
 * and acknowledge it to the sender.
 *
 * ep:          local endpoint
 * peer_endpt:  endpoint id of the sender
 * peer_req:    sender-side request pointer, echoed back in the ACK
 * src_segs:    source description taken from the sender's request
 * src_nsegs:   number of source segments
 * src_session: session value supplied by the sender
 * recv:        receive request; recv.basic.status.xfer_length bytes are
 *              transferred into its segments
 *
 * The transfer is done by mx__direct_getv(); any failure is fatal.
 */
void
mx__shm_copy(struct mx_endpoint *ep, uint8_t peer_endpt, uint64_t peer_req,
	     mx_shm_seg_t *src_segs, uint32_t src_nsegs, uint32_t src_session,
	     union mx_request *recv)
{
  int rc;
  mx_direct_getv_t xget;

  /* Describe the destination the same way the sender describes its
     source: inline for one segment, a temporary malloc'ed array of
     descriptors for several. */
  if (recv->recv.count > 1) {
    int i;
    mx_shm_seg_t * segs = mx_malloc (recv->recv.count * sizeof(*segs));
    mx_fixme_assert(segs);
    for(i=0; i < recv->recv.count; i++) {
      segs[i].vaddr = (uintptr_t) recv->recv.segments[i].segment_ptr;
      segs[i].len = recv->recv.segments[i].segment_length;
    }
    xget.dst_segs.vaddr = (uintptr_t) segs;
  } else {
    xget.dst_segs.vaddr = (uintptr_t)recv->recv.segments[0].segment_ptr;
    xget.dst_segs.len = recv->recv.segments[0].segment_length;
  }
  xget.dst_nsegs = recv->recv.count;

  xget.src_segs = *src_segs;
  xget.src_nsegs = src_nsegs;

  xget.length = recv->recv.basic.status.xfer_length;
  xget.src_endpt = peer_endpt;
  xget.src_board_num = ep->board_num;
  xget.src_session = src_session;
  rc = mx__direct_getv(ep->handle, &xget);
  if (rc != MX_SUCCESS) {
    /* If the session is wrong, we get EPERM.
     * We do not return MX_STATUS_BAD_SESSION to the sender in that case,
     * since it would be a bug in the library. */
    if (errno == EIO) {
      mx_fatal("mx__direct_get failed, check kernel logs for error messages");
    } else {
      mx_printf("mx__direct_get src=%p:%d,dst=%p:%d,length=%u\n"
		"\tep%d <- ep%d:"
		" errno=%d:%s\n",
		(void*)(uintptr_t)xget.src_segs.vaddr, (int) xget.src_nsegs,
		(void*)(uintptr_t)xget.dst_segs.vaddr, (int) xget.dst_nsegs,
		xget.length,
		ep->myself->eid, xget.src_endpt,
		errno, strerror(errno));
      mx_fatal("mx__direct_get failed");
    }
  }
  /* release the temporary destination descriptor array */
  if (recv->recv.count > 1)
    mx_free ((void*)(uintptr_t) xget.dst_segs.vaddr);
  /* the ACK goes through the sender's control queue: map it if needed */
  if (!ep->shm->peers[peer_endpt].snd_shmq) {
    ep->shm->peers[peer_endpt].snd_shmq = mx__shm_open(ep, peer_endpt, 0, 0);
    mx_fixme_assert( ep->shm->peers[peer_endpt].snd_shmq);
  }
  /* Pass the list of segments back to the sender if it was malloc'ed, so
     that the sender may free it. */
  mx__shm_large_ack(ep->shm, peer_endpt, peer_req, xget.length,
	src_nsegs > 1 ? src_segs[0].vaddr : 0, src_nsegs);
  mx__shm_wake_if_needed(ep, ep->shm->peers + peer_endpt);
}

/*
 * Hand one incoming small, medium or large shared-memory request to the
 * matching layer.
 *
 * ep:   local endpoint
 * sreq: the request as read from the shared queue
 * data: payload address for small/medium requests (slot buffer or FIFO
 *       area); NULL for large requests, whose payload stays with the
 *       sender until mx__shm_copy() fetches it
 *
 * Outcomes of mx__endpoint_match_receive():
 *  - discard set: nothing more is done here;
 *  - a posted receive matched: it is filled in (copy from `data`, or
 *    mx__shm_copy() for a large request) and completed;
 *  - nothing matched: an unexpected request is allocated and queued on
 *    the context's unexpq; small/medium payloads are copied into a
 *    malloc'ed buffer, large requests only record how to fetch the data.
 *
 * NOTE(review): in the matched case the request is completed before its
 * status.source and partner fields are set at the end of the function -
 * confirm that mx__recv_complete() does not hand the request to the
 * application or recycle it first.
 */
static inline void
mx__shm_recv(mx_endpoint_t ep, struct mx__shmreq *sreq, void *data)
{
  uint64_t match_info = sreq->match_info;
  uint32_t ctxid = CTXID_FROM_MATCHING(ep, match_info);
  struct mx__partner * partner;
  union mx_request *rreq;
  uint8_t type;
  uint32_t discard;

  partner = mx__endpoint_lookup_partner(ep, sreq->peer_endpt, ntohs(ep->myself->peer_index_n));
  /* same mask as for a slot index: leaves the type without its valid bit */
  type = MX__SHM_REQ_SLOT(sreq->type);
  rreq = mx__endpoint_match_receive(ep, partner, match_info, sreq->length, data, &discard);
  if (discard) {
    /* this message has been processed by the unexpected handler, do nothing else */
    return;
  } else if (rreq) {
    /* rreq matched this message */
    mx_assert(rreq->recv.unexpected == 0);
    mx__spliceout_request(&ep->ctxid[ctxid].recv_reqq, rreq);
    rreq->recv.basic.state |= MX__REQUEST_STATE_RECV_MATCHED;
    rreq->recv.basic.status.match_info = match_info;
    rreq->recv.basic.status.msg_length = sreq->length;
    /* never transfer more than the receive can hold */
    rreq->recv.basic.status.xfer_length =
      MX_MIN(rreq->recv.r_length, rreq->recv.basic.status.msg_length);
    if (type != MX__SHM_REQ_LARGE) {
      mx__copy_to_segments(rreq->recv.segments, rreq->recv.count,
			   rreq->recv.memory_context, 0, data,
			   rreq->recv.basic.status.xfer_length);
    } else {
      mx__shm_copy(ep, sreq->peer_endpt, sreq->req_ptr, &sreq->src_segs, sreq->src_nsegs,
		   sreq->src_session, rreq);
    }
    mx__recv_complete(ep, rreq, MX_STATUS_SUCCESS);
  } else {
    /* no receive matched, store the message as an unexpected */
    rreq = mx__rl_alloc(ep);
    if (rreq == NULL) {
      mx_fatal("mx__shmem_luigi:out of resources");
    }
    rreq->recv.segments = &rreq->recv.segment;
    rreq->recv.count = 1;
    rreq->recv.segment.segment_length = 0;
    rreq->recv.segment.segment_ptr = 0;
    rreq->recv.basic.status.match_info = match_info;
    rreq->recv.basic.status.msg_length = sreq->length;
    rreq->recv.basic.status.xfer_length = sreq->length;
    rreq->recv.basic.wq = NULL;
    rreq->recv.unexpected = 1;
    rreq->recv.put_target = 0;
    mx__enqueue_request(&ep->ctxid[ctxid].unexpq, rreq);
    if (type == MX__SHM_REQ_LARGE) {
      /* keep what is needed to fetch the data once a receive matches */
      rreq->recv.basic.type = MX__REQUEST_TYPE_RECV_SHM;
      rreq->recv.basic.state = MX__REQUEST_STATE_PENDING;
      rreq->recv.shm_peer_req = sreq->req_ptr;
      rreq->recv.shm_src_segs = sreq->src_segs;
      rreq->recv.shm_src_nsegs = sreq->src_nsegs;
      rreq->recv.shm_src_session = sreq->src_session;
      rreq->recv.shm_peer_endpt = sreq->peer_endpt;
      rreq->recv.segment.segment_ptr = NULL;
    } else if (sreq->length) {
      /* handle it like a unexpected medium from the network */
      void *segment_ptr = mx_malloc(sreq->length);
      rreq->recv.basic.type = MX__REQUEST_TYPE_RECV;
      rreq->recv.basic.state = MX__REQUEST_STATE_COMPLETED;
      if (segment_ptr == NULL) {
	mx_fatal("Warning: mx__self_send/unexp:mx_malloc failed");
      }
      rreq->recv.ordered_unexp_weight = 0;
      ep->unexp_queue_length += sreq->length;
      rreq->recv.segment.segment_ptr = MX_VA_TO_SEGMENT_PTR(segment_ptr);
      rreq->recv.segment.segment_length = sreq->length;
      rreq->recv.count = 1;
      /* the slot / FIFO area is reused, so the payload must be copied out */
      mx_memcpy(segment_ptr, data, sreq->length);
    } else {
      /* zero-length message: nothing to buffer */
      rreq->recv.basic.type = MX__REQUEST_TYPE_RECV;
      rreq->recv.basic.state = MX__REQUEST_STATE_COMPLETED;
      rreq->recv.segment.segment_ptr = NULL;
      rreq->recv.segment.segment_length = 0;
      rreq->recv.count = 1;
    }
  }
  mx__partner_to_addr(partner, &rreq->basic.status.source);
  rreq->basic.partner = partner;
}

/*
 * Poll this endpoint's own shared request queue: handle at most one
 * entry, then try to flush parked bounce requests.
 *
 * The slot at read_idx holds a new entry when the MX__SHM_REQ_CNT bit of
 * its type equals the same bit of read_idx (see the valid bit produced by
 * mx__shm_queue_slot()).  If it does not, only the bounce list is
 * processed.
 *
 * Entry types:
 *  - SMALL:  payload is in the slot buffer;
 *  - MEDIUM: payload is in the sender's data FIFO and is released once it
 *            has been delivered;
 *  - LARGE:  no payload here, see mx__shm_recv();
 *  - ACK:    completes the large send named by req_ptr and frees the
 *            segment array allocated by mx__shm_send() when the send had
 *            more than one segment.
 * read_idx is advanced only after the entry has been handled.
 */
void mx__shmem_luigi(mx_endpoint_t ep)
{
  union mx_request *acked_req;
  struct mx__shmreq *sreq;
  struct mx__shm_queue *shmq;
  struct mx__shm_peer *peer;
  uint8_t type;
  void *data;
  shmq = ep->shm->shmq;
  sreq  = shmq->queue + MX__SHM_REQ_SLOT(shmq->read_idx);
  type = sreq->type;
  if ((type ^ shmq->read_idx) & MX__SHM_REQ_CNT) {
    /* nothing new in the queue */
    mx__process_bounces(ep);
    return;
  }
  /* read the rest of the slot only after its type has been seen */
  MX_READBAR();
  MX_VALGRIND_MEMORY_MAKE_READABLE(sreq, sizeof(*sreq));
  switch (MX__SHM_REQ_SLOT(type)) {
  case MX__SHM_REQ_SMALL:
    mx__shm_recv(ep, sreq, sreq->buf);
    shmq->read_idx += 1;
    break;
  case MX__SHM_REQ_LARGE:
    mx__shm_recv(ep, sreq, NULL);
    shmq->read_idx += 1;
    break;
  case MX__SHM_REQ_MEDIUM:
    peer = ep->shm->peers + sreq->peer_endpt;
    data = mx__shm_fifo_recv(ep, peer, sreq->length);
    mx__shm_recv(ep, sreq, data);
    mx__shm_fifo_free(ep, peer, sreq->length);
    shmq->read_idx += 1;
    break;
  case MX__SHM_REQ_ACK:
    acked_req = (union mx_request *)(uintptr_t)sreq->req_ptr;
    /* the magic was derived from the request address in mx__shm_send() */
    mx_assert(acked_req->send.shm_magic == (sreq->req_ptr ^ MX__SHM_REQ_MAGIC));
    if (sreq->src_nsegs > 1)
      mx_free ((void*)(uintptr_t) sreq->src_segs.vaddr);
    acked_req->send.shm_magic = 0;
    acked_req->send.basic.status.xfer_length = sreq->length;
    mx__send_complete(ep, acked_req, MX_STATUS_SUCCESS);
    ep->sendshm_count -= 1;
    shmq->read_idx += 1;
    break;
  default:
    mx_fatal("unknown type");
  }
  /* Mark the whole consumed slot inaccessible (the size of the request,
     not of the pointer), then re-open only its type field, which the
     next poll of this slot reads. */
  MX_VALGRIND_MEMORY_MAKE_NOACCESS(sreq, sizeof(*sreq));
  MX_VALGRIND_MEMORY_MAKE_READABLE(&sreq->type, sizeof(sreq->type));
  mx__process_bounces(ep);
}

#endif /* MX_USE_SHMEM */
